{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Synthetic Images from simulated data\n",
    "\n",
    "## Authors\n",
    "Yi-Hao Chen, Sebastian Heinz, Kelle Cruz, Stephanie T. Douglas\n",
    "\n",
    "## Learning Goals\n",
    "\n",
    "- Assign WCS astrometry to an image using ```astropy.wcs``` \n",
    "- Construct a PSF using ```astropy.modeling.model```\n",
    "- Convolve raw data with PSF using ```astropy.convolution```\n",
    "- Calculate polarization fraction and angle from Stokes I, Q, U data\n",
    "- Overplot quivers on the image\n",
    "\n",
    "## Keywords\n",
    "modeling, convolution, coordinates, WCS, FITS, radio astronomy, matplotlib, colorbar\n",
    "\n",
    "## Summary\n",
    "In this tutorial, we will:\n",
    "\n",
    "[1. Load and examine the FITS file](#1.-Load-and-examine-the-FITS-file)\n",
    "\n",
    "[2. Set up astrometry coordinates](#2.-Set-up-astrometry-coordinates)\n",
    "\n",
    "[3. Prepare a Point Spread Function (PSF)](#3.-Prepare-a-Point-Spread-Function-(PSF))\n",
    "\n",
    ">[3.a How to do this without astropy kernels](#3.a-How-to-do-this-without-astropy-kernels)\n",
    "\n",
    "[4. Convolve image with PSF](#4.-Convolve-image-with-PSF)\n",
    "\n",
    "[5. Convolve Stokes Q and U images](#5.-Convolve-Stokes-Q-and-U-images)\n",
    "\n",
    "[6. Calculate polarization angle and fraction for quiver plot](#6.-Calculate-polarization-angle-and-fraction-for-quiver-plot)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from astropy.utils.data import download_file\n",
    "from astropy.io import fits\n",
    "from astropy import units as u\n",
    "from astropy.coordinates import SkyCoord\n",
    "from astropy.wcs import WCS\n",
    "\n",
    "from astropy.convolution import Gaussian2DKernel\n",
    "from astropy.modeling.models import Lorentz1D\n",
    "from astropy.convolution import convolve_fft\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Load and examine the FITS file\n",
    "\n",
    "Here we begin with a 2-dimensional data that were stored in FITS format from some simulations. We have Stokes I, Q, and U maps. We we'll first load a FITS file and examine the header."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "file_i = download_file(\n",
    "    'http://data.astropy.org/tutorials/synthetic-images/synchrotron_i_lobe_0700_150MHz_sm.fits', \n",
    "    cache=True)\n",
    "hdulist = fits.open(file_i)\n",
    "hdulist.info()\n",
    "\n",
    "hdu = hdulist['NN_EMISSIVITY_I_LOBE_150.0MHZ']\n",
    "hdu.header"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see this FITS file, which was created in [yt](https://yt-project.org/), has x and y coordinate in physical units (cm). We want to convert it into sky coordinates. Before we proceed, let's find out the range of the data and plot a histogram. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(hdu.data.max())\n",
    "print(hdu.data.min())\n",
    "np.seterr(divide='ignore') #suppress the warnings raised by taking log10 of data with zeros\n",
    "plt.hist(np.log10(hdu.data.flatten()), range=(-3, 2), bins=100);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once we know the range of the data, we can do a visualization with the proper range (```vmin``` and ```vmax```)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(6,12))\n",
    "fig.add_subplot(111)\n",
    "\n",
    "# We plot it in log-scale and add a small number to avoid nan values. \n",
    "plt.imshow(np.log10(hdu.data+1E-3), vmin=-1, vmax=1, origin='lower')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Set up astrometry coordinates\n",
    "\n",
    "From the header, we know that the x and y axes are in centimeter. However, in an observation we usually have RA and Dec. To convert physical units to sky coordinates, we will need to make some assumptions about where the object is located, i.e. the distance to the object and the central RA and Dec. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# distance to the object\n",
    "dist_obj = 200*u.Mpc\n",
    "\n",
    "# We have the RA in hh:mm:ss and DEC in dd:mm:ss format. \n",
    "# We will use Skycoord to convert them into degrees later.\n",
    "ra_obj = '19h59m28.3566s'\n",
    "dec_obj = '+40d44m02.096s'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here we convert the pixel scale from cm to degree by dividing the distance to the object."
   ]
  },
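  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sketch of the reasoning behind this conversion (the symbols $\\Delta x$, $D$, and $\\Delta\\theta$ are introduced here only for illustration): a pixel of physical size $\\Delta x$ seen from a distance $D \\gg \\Delta x$ subtends, in the small-angle approximation,\n",
    "\n",
    "$$\\Delta\\theta \\approx \\frac{\\Delta x}{D}\\,\\mathrm{rad} = \\frac{\\Delta x}{D}\\,\\frac{180}{\\pi}\\,\\mathrm{deg}.$$\n",
    "\n",
    "For a sense of scale, at $D = 200$ Mpc one arcsecond ($\\approx 4.848\\times10^{-6}$ rad) corresponds to roughly $200\\,\\mathrm{Mpc} \\times 4.848\\times10^{-6} \\approx 970$ pc. This treats the distance as a simple Euclidean distance and ignores cosmological corrections."
   ]
  },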
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "cdelt1 = ((hdu.header['CDELT1']*u.cm/dist_obj.to('cm'))*u.rad).to('deg')\n",
    "cdelt2 = ((hdu.header['CDELT2']*u.cm/dist_obj.to('cm'))*u.rad).to('deg')\n",
    "print(cdelt1, cdelt2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Use ```astropy.wcs.WCS``` to prepare a FITS header."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "w = WCS(naxis=2)\n",
    "\n",
    "# reference pixel coordinate\n",
    "w.wcs.crpix = [hdu.data.shape[0]/2,hdu.data.shape[1]/2]\n",
    "\n",
    "# sizes of the pixel in degrees\n",
    "w.wcs.cdelt = [-cdelt1.base, cdelt2.base]\n",
    "\n",
    "# converting ra and dec into degrees\n",
    "c = SkyCoord(ra_obj, dec_obj)\n",
    "w.wcs.crval = [c.ra.deg, c.dec.deg]\n",
    "\n",
    "# the units of the axes are in degrees\n",
    "w.wcs.cunit = ['deg', 'deg']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can convert the WCS coordinate into header and update the hdu."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "wcs_header = w.to_header()\n",
    "hdu.header.update(wcs_header)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's take a look at the header. ```CDELT1```, ```CDELT2```, ```CUNIT1```, ```CUNIT2```, ```CRVAL1```, and ```CRVAL2``` are in sky coordinates now."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "hdu.header"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "wcs = WCS(hdu.header)\n",
    "\n",
    "fig = plt.figure(figsize=(6,12))\n",
    "fig.add_subplot(111, projection=wcs)\n",
    "plt.imshow(np.log10(hdu.data+1e-3), vmin=-1, vmax=1, origin='lower')\n",
    "plt.xlabel('RA')\n",
    "plt.ylabel('Dec')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we have the sky coordinate for the image!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Prepare a Point Spread Function (PSF)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Simple PSFs are included in ```astropy.convolution.kernel```. We'll use ```astropy.convolution.Gaussian2DKernel``` here.\n",
    "First we need to set the telescope resolution. For a 2D Gaussian, we can calculate sigma in pixels by using our pixel scale keyword ```cdelt2``` from above."
   ]
  },
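  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an aside that is not part of the original steps: telescope resolution is often quoted as the full width at half maximum (FWHM) of the beam rather than as a Gaussian sigma. For a Gaussian the two are related by\n",
    "\n",
    "$$\\mathrm{FWHM} = 2\\sqrt{2\\ln 2}\\,\\sigma \\approx 2.355\\,\\sigma.$$\n",
    "\n",
    "If the 1 arcsecond resolution were interpreted as a FWHM, the corresponding sigma would be about 0.425 arcseconds. The code in this section instead uses the resolution directly as sigma."
   ]
  },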
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# assume our telescope has 1 arcsecond resolution\n",
    "telescope_resolution = 1*u.arcsecond\n",
    "\n",
    "# calculate the sigma in pixels. \n",
    "# since cdelt is in degrees, we use _.to('deg')\n",
    "sigma = telescope_resolution.to('deg')/cdelt2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# By default, the Gaussian kernel will go to 4 sigma\n",
    "# in each direction\n",
    "psf = Gaussian2DKernel(sigma)\n",
    "\n",
    "# let's take a look:\n",
    "plt.imshow(psf.array.value)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3.a How to do this without astropy kernels"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Maybe your PSF is more complicated. Here's an alternative way to do this, using a 2D Lorentzian"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# set FWHM and psf grid\n",
    "telescope_resolution = 1*u.arcsecond\n",
    "gamma = telescope_resolution.to('deg')/cdelt2\n",
    "x_grid = np.outer(np.linspace(-gamma*4,gamma*4,int(8*gamma)),np.ones(int(8*gamma)))\n",
    "r_grid = np.sqrt(x_grid**2 + np.transpose(x_grid**2))\n",
    "lorentzian = Lorentz1D(fwhm=2*gamma)\n",
    "\n",
    "# extrude a 2D azimuthally symmetric PSF\n",
    "lorentzian_psf = lorentzian(r_grid)\n",
    "\n",
    "# normalization\n",
    "lorentzian_psf /= np.sum(lorentzian_psf)\n",
    "\n",
    "# let's take a look again:\n",
    "plt.imshow(lorentzian_psf.value, interpolation='none')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Convolve image with PSF"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here we use ```astropy.convolution.convolve_fft``` to convolve image. This routine uses fourier transform for faster calculation. Especially since our data is $2^n$ sized, which makes it particually fast. Using a fft, however, causes boundary effects. We'll need to specify how we want to handle the boundary. Here we choose to \"wrap\" the data, which means making the data periodic. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "convolved_image = convolve_fft(hdu.data, psf, boundary='wrap')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Put a psf at the corner of the image\n",
    "delta_x_psf=100 # number of pixels from the edges\n",
    "xmin, xmax = -psf.shape[1]-delta_x_psf, -delta_x_psf\n",
    "ymin, ymax = delta_x_psf, delta_x_psf+psf.shape[0]\n",
    "convolved_image[xmin:xmax, ymin:ymax] = psf.array/psf.array.max()*10"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "Now let's take a look at the convolved image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "wcs = WCS(hdu.header)\n",
    "fig = plt.figure(figsize=(8,12))\n",
    "i_plot = fig.add_subplot(111, projection=wcs)\n",
    "plt.imshow(np.log10(convolved_image+1e-3), vmin=-1, vmax=1.0, origin='lower')#, cmap=plt.cm.viridis)\n",
    "plt.xlabel('RA')\n",
    "plt.ylabel('Dec')\n",
    "plt.colorbar()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Convolve Stokes Q and U images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "hdulist.info()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "file_q = download_file(\n",
    "    'http://data.astropy.org/tutorials/synthetic-images/synchrotron_q_lobe_0700_150MHz_sm.fits', \n",
    "    cache=True)\n",
    "hdulist = fits.open(file_q)\n",
    "hdu_q = hdulist['NN_EMISSIVITY_Q_LOBE_150.0MHZ']\n",
    "\n",
    "file_u = download_file(\n",
    "    'http://data.astropy.org/tutorials/synthetic-images/synchrotron_u_lobe_0700_150MHz_sm.fits', \n",
    "    cache=True)\n",
    "hdulist = fits.open(file_u)\n",
    "hdu_u = hdulist['NN_EMISSIVITY_U_LOBE_150.0MHZ']\n",
    "\n",
    "# Update the header with the wcs_header we created earlier\n",
    "hdu_q.header.update(wcs_header)\n",
    "hdu_u.header.update(wcs_header)\n",
    "\n",
    "# Convolve the images with the the psf\n",
    "convolved_image_q = convolve_fft(hdu_q.data, psf, boundary='wrap')\n",
    "convolved_image_u = convolve_fft(hdu_u.data, psf, boundary='wrap')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's plot the Q and U images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "wcs = WCS(hdu.header)\n",
    "fig = plt.figure(figsize=(16,12))\n",
    "fig.add_subplot(121, projection=wcs)\n",
    "plt.imshow(convolved_image_q, cmap='seismic', vmin=-0.5, vmax=0.5, origin='lower')#, cmap=plt.cm.viridis)\n",
    "plt.xlabel('RA')\n",
    "plt.ylabel('Dec')\n",
    "plt.colorbar()\n",
    "\n",
    "fig.add_subplot(122, projection=wcs)\n",
    "plt.imshow(convolved_image_u, cmap='seismic', vmin=-0.5, vmax=0.5, origin='lower')#, cmap=plt.cm.viridis)\n",
    "\n",
    "plt.xlabel('RA')\n",
    "plt.ylabel('Dec')\n",
    "plt.colorbar()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Calculate polarization angle and fraction for quiver plot "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that rotating Stokes Q and I maps requires changing signs of both. Here we assume that the Stokes q and u maps were calculated defining the y/declination axis as vertical, such that Q is positive for polarization vectors along the x/right-ascention axis."
   ]
  },
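  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The quantities computed in the next cell follow the usual definitions of the linear polarization angle and fraction (written here with $\\psi$ and $p$ as illustrative symbols):\n",
    "\n",
    "$$\\psi = \\tfrac{1}{2}\\,\\mathrm{arctan2}(U, Q), \\qquad p = \\frac{\\sqrt{Q^2 + U^2}}{I}.$$\n",
    "\n",
    "Under a rotation of the reference axes by an angle $\\theta$, and with one common sign convention,\n",
    "\n",
    "$$Q' = Q\\cos 2\\theta + U\\sin 2\\theta, \\qquad U' = -Q\\sin 2\\theta + U\\cos 2\\theta,$$\n",
    "\n",
    "so a rotation by $\\theta = 90^\\circ$ gives $Q' = -Q$ and $U' = -U$."
   ]
  },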
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# First, we plot the background image\n",
    "fig = plt.figure(figsize=(8,16))\n",
    "i_plot = fig.add_subplot(111, projection=wcs)\n",
    "i_plot.imshow(np.log10(convolved_image+1e-3), vmin=-1, vmax=1, origin='lower')\n",
    "\n",
    "# ranges of the axis\n",
    "xx0, xx1 = i_plot.get_xlim()\n",
    "yy0, yy1 = i_plot.get_ylim()\n",
    "\n",
    "# binning factor\n",
    "factor = [64, 66]\n",
    "\n",
    "# re-binned number of points in each axis\n",
    "nx_new = convolved_image.shape[1] // factor[0]\n",
    "ny_new = convolved_image.shape[0] // factor[1]\n",
    "\n",
    "# These are the positions of the quivers\n",
    "X,Y = np.meshgrid(np.linspace(xx0,xx1,nx_new,endpoint=True),\n",
    "                  np.linspace(yy0,yy1,ny_new,endpoint=True))\n",
    "\n",
    "# bin the data\n",
    "I_bin = convolved_image.reshape(nx_new, factor[0], ny_new, factor[1]).sum(3).sum(1)\n",
    "Q_bin = convolved_image_q.reshape(nx_new, factor[0], ny_new, factor[1]).sum(3).sum(1)\n",
    "U_bin = convolved_image_u.reshape(nx_new, factor[0], ny_new, factor[1]).sum(3).sum(1)\n",
    "\n",
    "# polarization angle\n",
    "psi = 0.5*np.arctan2(U_bin, Q_bin)\n",
    "\n",
    "# polarization fraction\n",
    "frac = np.sqrt(Q_bin**2+U_bin**2)/I_bin\n",
    "\n",
    "# mask for low signal area\n",
    "mask = I_bin < 0.1\n",
    "\n",
    "frac[mask] = 0\n",
    "psi[mask] = 0\n",
    "\n",
    "pixX = frac*np.cos(psi) # X-vector \n",
    "pixY = frac*np.sin(psi) # Y-vector\n",
    "\n",
    "# keyword arguments for quiverplots\n",
    "quiveropts = dict(headlength=0, headwidth=1, pivot='middle')\n",
    "i_plot.quiver(X, Y, pixX, pixY, scale=8, **quiveropts)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Exercise"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Convert the units of the data from Jy/arcsec^2 to Jy/beam"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "The intensity of the data is given in unit of Jy/arcsec^2. Observational data usually have the intensity unit in Jy/beam. Assuming a beam size or take the psf we created earlier, you can convert the data into Jy/beam."
   ]
  },
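  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a hint, and assuming the beam is a circular Gaussian (the symbols below are introduced only for illustration), the beam solid angle is\n",
    "\n",
    "$$\\Omega_{\\mathrm{beam}} = 2\\pi\\sigma^2 = \\frac{\\pi\\,\\theta_{\\mathrm{FWHM}}^2}{4\\ln 2} \\approx 1.133\\,\\theta_{\\mathrm{FWHM}}^2,$$\n",
    "\n",
    "and a surface brightness in Jy/arcsec^2 is converted to Jy/beam by multiplying by $\\Omega_{\\mathrm{beam}}$ expressed in arcsec^2. For example, a Gaussian with $\\sigma = 1$ arcsecond has $\\Omega_{\\mathrm{beam}} = 2\\pi \\approx 6.28$ arcsec^2."
   ]
  },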
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
